/*
 * Copyright (c) 2019-2025, NVIDIA CORPORATION.
 *
 * Copyright 2018-2019 BlazingDB, Inc.
 *     Copyright 2018 Christian Noboa Mardini <christian@blazingdb.com>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "assert_unary.h"

#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/random.hpp>
#include <cudf_test/type_lists.hpp>

#include <cudf/detail/iterator.cuh>
#include <cudf/jit/runtime_support.hpp>
#include <cudf/transform.hpp>

namespace transformation {
struct UnaryOperationIntegrationTest : public cudf::test::BaseFixture {
 protected:
  // Every test in this fixture compiles a UDF at runtime, so skip the whole
  // suite when the runtime JIT compiler is unavailable.
  void SetUp() override
  {
    bool const jit_available = cudf::is_runtime_jit_supported();
    if (!jit_available) { GTEST_SKIP() << "Skipping tests that require runtime JIT support"; }
  }
};

template <class dtype, class Op, class Data>
void test_udf(char const* udf, Op op, Data data_init, cudf::size_type size, bool is_ptx)
{
  auto all_valid = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return true; });
  auto data_iter = cudf::detail::make_counting_transform_iterator(0, data_init);

  cudf::test::fixed_width_column_wrapper<dtype, typename decltype(data_iter)::value_type> in(
    data_iter, data_iter + size, all_valid);

  std::unique_ptr<cudf::column> out =
    cudf::transform({in}, udf, cudf::data_type(cudf::type_to_id<dtype>()), is_ptx);

  ASSERT_UNARY<dtype, dtype>(out->view(), in, op);
}

// Exercises both runtime-JIT input flavors (CUDA C++ source and pre-built
// PTX) with a float -> float UDF computing a^4, checked against the
// equivalent host lambda.
TEST_F(UnaryOperationIntegrationTest, Transform_FP32_FP32)
{
  // c = a*a*a*a
  std::string const cuda =
    R"***(
__device__ inline void    fdsf   (
       float* C,
       float a
)
{
  *C = a*a*a*a;
}
)***";

  // PTX equivalent of the UDF above (Numba/NVVM output): computes
  // f2 = a*a then f3 = f2*f2 (i.e. a^4) and stores it through param_0.
  std::string const ptx =
    R"***(
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//

.version 6.3
.target sm_70
.address_size 64

	// .globl	_ZN8__main__7add$241Ef
.common .global .align 8 .u64 _ZN08NumbaEnv8__main__7add$241Ef;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets7numbers14int_power_impl12$3clocals$3e13int_power$242Efx;

.visible .func  (.param .b32 func_retval0) _ZN8__main__7add$241Ef(
	.param .b64 _ZN8__main__7add$241Ef_param_0,
	.param .b32 _ZN8__main__7add$241Ef_param_1
)
{
	.reg .f32 	%f<4>;
	.reg .b32 	%r<2>;
	.reg .b64 	%rd<2>;


	ld.param.u64 	%rd1, [_ZN8__main__7add$241Ef_param_0];
	ld.param.f32 	%f1, [_ZN8__main__7add$241Ef_param_1];
	mul.f32 	%f2, %f1, %f1;
	mul.f32 	%f3, %f2, %f2;
	st.f32 	[%rd1], %f3;
	mov.u32 	%r1, 0;
	st.param.b32	[func_retval0+0], %r1;
	ret;
}
)***";

  using dtype    = float;
  // Host-side reference implementation matching both device UDFs.
  auto op        = [](dtype a) { return a * a * a * a; };
  // Input generator: rows cycle through 0, 1, 2.
  auto data_init = [](cudf::size_type row) { return row % 3; };

  // Same operation through both paths: CUDA source (is_ptx=false), PTX (true).
  test_udf<dtype>(cuda.c_str(), op, data_init, 500, false);
  test_udf<dtype>(ptx.c_str(), op, data_init, 500, true);
}

// int -> int UDF computing a*a - a, via both CUDA source and PTX.
TEST_F(UnaryOperationIntegrationTest, Transform_INT32_INT32)
{
  // c = a * a - a
  std::string const cuda =
    "__device__ inline void f(int* output,int input){*output = input*input - input;}";

  // PTX equivalent: r2 = a*a, r3 = r2 - a, stored through param_0.
  std::string const ptx =
    R"***(
.func _Z1fPii(
        .param .b64 _Z1fPii_param_0,
        .param .b32 _Z1fPii_param_1
)
{
        .reg .b32       %r<4>;
        .reg .b64       %rd<3>;


        ld.param.u64    %rd1, [_Z1fPii_param_0];
        ld.param.u32    %r1, [_Z1fPii_param_1];
        cvta.to.global.u64      %rd2, %rd1;
        mul.lo.s32      %r2, %r1, %r1;
        sub.s32         %r3, %r2, %r1;
        st.global.u32   [%rd2], %r3;
        ret;
}
)***";

  using dtype    = int;
  // Host-side reference implementation matching both device UDFs.
  auto op        = [](dtype a) { return a * a - a; };
  // Input generator: rows cycle through 0..77.
  auto data_init = [](cudf::size_type row) { return row % 78; };

  test_udf<dtype>(cuda.c_str(), op, data_init, 500, false);
  test_udf<dtype>(ptx.c_str(), op, data_init, 500, true);
}

// int8 -> int8 UDF that uppercases ASCII lowercase letters ('a'..'z', i.e.
// codes 97..122) and passes every other value through unchanged.
TEST_F(UnaryOperationIntegrationTest, Transform_INT8_INT8)
{
  // Capitalize all the lower case letters
  // Assuming ASCII, the PTX code is compiled from the following CUDA code

  std::string const cuda =
    R"***(
__device__ inline void f(
  signed char* output,
  signed char input
){
	if(input > 96 && input < 123){
  	*output = input - 32;
  }else{
  	*output = input;
  }
}
)***";

  // PTX equivalent: subtracts 97 ('a'), unsigned-compares < 26 to detect
  // lowercase, and conditionally adds 224 (== -32 mod 256) to uppercase.
  std::string const ptx =
    R"***(
.func _Z1fPcc(
        .param .b64 _Z1fPcc_param_0,
        .param .b32 _Z1fPcc_param_1
)
{
        .reg .pred      %p<2>;
        .reg .b16       %rs<6>;
        .reg .b32       %r<3>;
        .reg .b64       %rd<3>;


        ld.param.u64    %rd1, [_Z1fPcc_param_0];
        cvta.to.global.u64      %rd2, %rd1;
        ld.param.s8     %rs1, [_Z1fPcc_param_1];
        add.s16         %rs2, %rs1, -97;
        and.b16         %rs3, %rs2, 255;
        setp.lt.u16     %p1, %rs3, 26;
        cvt.u32.u16     %r1, %rs1;
        add.s32         %r2, %r1, 224;
        cvt.u16.u32     %rs4, %r2;
        selp.b16        %rs5, %rs4, %rs1, %p1;
        st.global.u8    [%rd2], %rs5;
        ret;
}
)***";

  using dtype    = int8_t;
  // std::toupper returns int; the implicit narrowing back to int8_t is safe
  // here because data_init below restricts inputs to 'a'..'z'.
  auto op        = [](dtype a) { return std::toupper(a); };
  // Input generator: rows cycle through 'a'..'z'.
  auto data_init = [](cudf::size_type row) { return 'a' + (row % 26); };

  test_udf<dtype>(cuda.c_str(), op, data_init, 500, false);
  test_udf<dtype>(ptx.c_str(), op, data_init, 500, true);
}

// Adds one day (duration<int32_t, ratio<86400>> seconds) to microsecond
// timestamps. Only the CUDA-source path is exercised; there is no PTX
// variant in this test.
TEST_F(UnaryOperationIntegrationTest, Transform_Datetime)
{
  // Add one day to timestamp in microseconds

  std::string const cuda =
    R"***(
__device__ inline void f(cudf::timestamp_us* output, cudf::timestamp_us input)
{
  using dur = cuda::std::chrono::duration<int32_t, cuda::std::ratio<86400>>;
  *output = static_cast<cudf::timestamp_us>(input + dur{1});
}

)***";

  using dtype = cudf::timestamp_us;
  // Host-side reference: the same one-day addition as the device UDF.
  auto op     = [](dtype a) {
    using dur = cuda::std::chrono::duration<int32_t, cuda::std::ratio<86400>>;
    return static_cast<cudf::timestamp_us>(a + dur{1});
  };
  // Random timestamp representations in [0, 100000000] microseconds
  // (presumably inclusive bounds — see UniformRandomGenerator).
  auto random_eng = cudf::test::UniformRandomGenerator<cudf::timestamp_us::rep>(0, 100000000);
  auto data_init  = [&random_eng](cudf::size_type row) { return random_eng.generate(); };

  test_udf<dtype>(cuda.c_str(), op, data_init, 500, false);
}

struct TernaryOperationTest : public cudf::test::BaseFixture {
 protected:
  // Gate on runtime JIT availability: these tests compile UDFs at runtime,
  // so the fixture skips them when the JIT compiler is not supported.
  void SetUp() override
  {
    bool const jit_available = cudf::is_runtime_jit_supported();
    if (!jit_available) { GTEST_SKIP() << "Skipping tests that require runtime JIT support"; }
  }
};

// Ternary UDF: out = (a + b) * c, where columns a and b have 200 rows and
// column c has a single row (scalar-like input; presumably broadcast across
// all rows by cudf::transform — the 200-row expected column assumes so).
// Exercises both the CUDA-source and PTX paths.
TEST_F(TernaryOperationTest, TransformWithScalar)
{
  std::string const cuda =
    R"***(
__device__ inline void transform(
       float* out,
       float a,
       float b,
       float c
)
{
  *out = (a + b) * c;
}
)***";

  // Generated from NUMBA, using:
  //
  // ```py
  //
  // from numba import cuda, float32
  // from numba.cuda import compile_ptx_for_current_device
  //
  // # Define a CUDA device function
  //
  // @cuda.jit(device=True)
  // def op(a, b, c):
  //         return (a + b) * c
  //
  // # Define argument types for the function
  // arg_types = (float32, float32, float32)
  //
  // # Compile the device function as relocatable
  // ptx, _ = cuda.compile_ptx_for_current_device(op, arg_types, device=True)
  //
  //
  // # Print the PTX code
  // print("Relocatable PTX Code:")
  // print(ptx)
  //
  //
  // ```
  //
  std::string const ptx =
    R"***(
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-35404655
// Cuda compilation tools, release 12.8, V12.8.61
// Based on NVVM 7.0.1
//

.version 8.7
.target sm_86
.address_size 64

	// .globl	_ZN8__main__2opB2v1B96cw51cXTLSUwv1sCUt9Ww0FEw09RRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dEfff
.common .global .align 8 .u64 _ZN08NumbaEnv8__main__2opB2v1B96cw51cXTLSUwv1sCUt9Ww0FEw09RRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dEfff;

.visible .func  (.param .b32 func_retval0) _ZN8__main__2opB2v1B96cw51cXTLSUwv1sCUt9Ww0FEw09RRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dEfff(
	.param .b64 _ZN8__main__2opB2v1B96cw51cXTLSUwv1sCUt9Ww0FEw09RRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dEfff_param_0,
	.param .b32 _ZN8__main__2opB2v1B96cw51cXTLSUwv1sCUt9Ww0FEw09RRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dEfff_param_1,
	.param .b32 _ZN8__main__2opB2v1B96cw51cXTLSUwv1sCUt9Ww0FEw09RRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dEfff_param_2,
	.param .b32 _ZN8__main__2opB2v1B96cw51cXTLSUwv1sCUt9Ww0FEw09RRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dEfff_param_3
)
{
	.reg .f32 	%f<6>;
	.reg .b32 	%r<2>;
	.reg .b64 	%rd<2>;


	ld.param.u64 	%rd1, [_ZN8__main__2opB2v1B96cw51cXTLSUwv1sCUt9Ww0FEw09RRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dEfff_param_0];
	ld.param.f32 	%f1, [_ZN8__main__2opB2v1B96cw51cXTLSUwv1sCUt9Ww0FEw09RRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dEfff_param_1];
	ld.param.f32 	%f2, [_ZN8__main__2opB2v1B96cw51cXTLSUwv1sCUt9Ww0FEw09RRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dEfff_param_2];
	ld.param.f32 	%f3, [_ZN8__main__2opB2v1B96cw51cXTLSUwv1sCUt9Ww0FEw09RRQPKiLTj0gIGIFp_2b2oLQFEYYkHSQB1OQAk0Bynm21OizQ1K0UoIGvDpQE8oxrNQE_3dEfff_param_3];
	add.f32 	%f4, %f1, %f2;
	mul.f32 	%f5, %f4, %f3;
	st.f32 	[%rd1], %f5;
	mov.u32 	%r1, 0;
	st.param.b32 	[func_retval0+0], %r1;
	ret;

}
)***";

  using T = float;

  // Constant inputs and the exact expected result: (90 + 100) * 5 = 950.
  constexpr T A   = 90;
  constexpr T B   = 100;
  constexpr T C   = 5;
  constexpr T OUT = (A + B) * C;

  std::vector<T> a_host(200, A);
  std::vector<T> b_host(200, B);
  std::vector<T> c_host(1, C);  // single row: the scalar-like operand
  std::vector<T> expected_host(200, OUT);

  cudf::test::fixed_width_column_wrapper<T> a(a_host.begin(), a_host.end());
  cudf::test::fixed_width_column_wrapper<T> b(b_host.begin(), b_host.end());
  cudf::test::fixed_width_column_wrapper<T> c(c_host.begin(), c_host.end());
  cudf::test::fixed_width_column_wrapper<T> expected(expected_host.begin(), expected_host.end());

  std::unique_ptr<cudf::column> cuda_result =
    cudf::transform({a, b, c}, cuda, cudf::data_type(cudf::type_to_id<T>()), false);

  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*cuda_result, expected);

  std::unique_ptr<cudf::column> ptx_result =
    cudf::transform({a, b, c}, ptx, cudf::data_type(cudf::type_to_id<T>()), true);

  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*ptx_result, expected);
}

// Typed fixture for fixed-point (decimal) ternary transform tests.
template <typename T>
struct TernaryDecimalOperationTest : public cudf::test::BaseFixture {
 protected:
  // Skip the whole typed suite when the runtime JIT compiler is unavailable.
  void SetUp() override
  {
    bool const jit_available = cudf::is_runtime_jit_supported();
    if (!jit_available) { GTEST_SKIP() << "Skipping tests that require runtime JIT support"; }
  }
};

// Instantiate the typed suite for every cudf fixed-point type.
TYPED_TEST_SUITE(TernaryDecimalOperationTest, cudf::test::FixedPointTypes);

// Ternary UDF over fixed-point columns: out = (a + b) * c, where the third
// column holds a single row. The host computes the reference result (and its
// scale) with the same fixed-point arithmetic. Only the CUDA-source path is
// exercised here.
TYPED_TEST(TernaryDecimalOperationTest, TransformDecimalsAndScalar)
{
  using T = TypeParam;

  auto const type_name = cudf::type_to_name(cudf::data_type(cudf::type_to_id<T>()));

  // Build the UDF source for the fixed-point type under test.
  // clang-format off
  std::string const cuda =
    "__device__ void transform("
    + type_name + "* out, "
    + type_name + " a,"
    + type_name + " b,"
    + type_name + " c) {\n"
    + "*out = ((a + b) * c);"
    + " }";
  // clang-format on

  // Inputs at three different scales; the reference result carries the scale
  // produced by the fixed-point + and * operators.
  T const lhs(10, numeric::scale_type{0});
  T const rhs(20, numeric::scale_type{-1});
  T const factor(5, numeric::scale_type{-2});
  T const reference = ((lhs + rhs) * factor);

  constexpr cudf::size_type num_rows = 200;

  std::vector<typename T::rep> lhs_data(num_rows, lhs.value());
  std::vector<typename T::rep> rhs_data(num_rows, rhs.value());
  std::vector<typename T::rep> factor_data(1, factor.value());  // single-row operand
  std::vector<typename T::rep> expected_data(num_rows, reference.value());

  cudf::test::fixed_point_column_wrapper<typename T::rep> lhs_col(
    lhs_data.begin(), lhs_data.end(), lhs.scale());
  cudf::test::fixed_point_column_wrapper<typename T::rep> rhs_col(
    rhs_data.begin(), rhs_data.end(), rhs.scale());
  cudf::test::fixed_point_column_wrapper<typename T::rep> factor_col(
    factor_data.begin(), factor_data.end(), factor.scale());
  cudf::test::fixed_point_column_wrapper<typename T::rep> expected_col(
    expected_data.begin(), expected_data.end(), reference.scale());

  // The requested output type carries the reference result's scale.
  std::unique_ptr<cudf::column> cuda_result = cudf::transform(
    {lhs_col, rhs_col, factor_col},
    cuda,
    cudf::data_type(cudf::type_to_id<T>(), reference.scale()),
    false);

  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*cuda_result, expected_col);
}

}  // namespace transformation
